{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SQL Agent for the Spider Text-to-SQL Benchmark"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook demonstrates a basic SQL agent that translates natural language questions into SQL queries."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Environment\n",
    "\n",
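    "For this demo, we use a SQLite database environment based on a standard text-to-SQL benchmark called [Spider](https://yale-lily.github.io/spider). The environment provides a Gym-like interface. `reset()` samples a question together with the schema of its database, and `step(sql)` executes a query and returns the execution feedback and a reward. It can be used as follows."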
   ]
  },
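  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `reset()`/`step()` pattern can be illustrated with a small stand-in. The `ToySQLEnv` class below is hypothetical and not part of `spider_env`. It mimics only the calls this notebook makes: `reset()` returns `(observation, info)` and `step(sql)` returns a Gym-style five-tuple. It is backed by a one-table in-memory SQLite database, and it assumes a query earns reward 1 when its rows match the gold query's rows in any order. The real `SpiderEnv` may compare results differently.\n",
    "\n",
    "```python\n",
    "import sqlite3\n",
    "\n",
    "\n",
    "class ToySQLEnv:\n",
    "    # Hypothetical stand-in for SpiderEnv: one database and one question.\n",
    "    def __init__(self):\n",
    "        self.conn = sqlite3.connect(':memory:')\n",
    "        self.conn.execute('CREATE TABLE artist (Artist_ID int, Famous_Title text)')\n",
    "        self.conn.executemany('INSERT INTO artist VALUES (?, ?)', [(1, 'Title A'), (2, 'Title B')])\n",
    "        self.question = 'List all famous titles.'\n",
    "        self.gold_sql = 'SELECT Famous_Title FROM artist'\n",
    "\n",
    "    def _run(self, sql):\n",
    "        # Execute a query and return (rows, error message or None).\n",
    "        try:\n",
    "            return [list(row) for row in self.conn.execute(sql).fetchall()], None\n",
    "        except sqlite3.Error as exc:\n",
    "            return None, str(exc)\n",
    "\n",
    "    def reset(self):\n",
    "        observation = {'instruction': self.question}\n",
    "        info = {'schema': 'CREATE TABLE artist (Artist_ID int, Famous_Title text);'}\n",
    "        return observation, info\n",
    "\n",
    "    def step(self, sql):\n",
    "        result, error = self._run(sql)\n",
    "        gold, _ = self._run(self.gold_sql)\n",
    "        # Assumed reward rule: 1 if the query runs and its rows equal the gold rows, ignoring order.\n",
    "        reward = int(error is None and sorted(result) == sorted(gold))\n",
    "        observation = {'feedback': {'result': result, 'error': error}}\n",
    "        info = {'gold_result': gold}\n",
    "        # Gym-style five-tuple; the terminated/truncated flags are unused in this notebook.\n",
    "        return observation, reward, False, False, info\n",
    "\n",
    "\n",
    "env = ToySQLEnv()\n",
    "observation, info = env.reset()\n",
    "observation, reward, _, _, info = env.step('SELECT Famous_Title FROM artist ORDER BY Artist_ID DESC')\n",
    "print(observation['feedback'], reward)  # reward is 1: same rows, different order\n",
    "```"
   ]
  },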
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading cached Spider dataset from /home/wangdazhang/.cache/spider\n",
      "Schema file not found for /home/wangdazhang/.cache/spider/spider/database/flight_4\n",
      "Schema file not found for /home/wangdazhang/.cache/spider/spider/database/small_bank_1\n",
      "Schema file not found for /home/wangdazhang/.cache/spider/spider/database/icfp_1\n",
      "Schema file not found for /home/wangdazhang/.cache/spider/spider/database/twitter_1\n",
      "Schema file not found for /home/wangdazhang/.cache/spider/spider/database/epinions_1\n",
      "Schema file not found for /home/wangdazhang/.cache/spider/spider/database/chinook_1\n",
      "Schema file not found for /home/wangdazhang/.cache/spider/spider/database/company_1\n"
     ]
    }
   ],
   "source": [
    "# %pip install spider-env\n",
    "import json\n",
    "import os\n",
    "from typing import Annotated, Dict\n",
    "\n",
    "from spider_env import SpiderEnv\n",
    "\n",
    "from autogen import ConversableAgent, UserProxyAgent, config_list_from_json\n",
    "\n",
    "gym = SpiderEnv()\n",
    "\n",
    "# Randomly select a question from Spider\n",
    "observation, info = gym.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Find the famous titles of artists that do not have any volume.\n"
     ]
    }
   ],
   "source": [
    "# The natural language question\n",
    "question = observation[\"instruction\"]\n",
    "print(question)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CREATE TABLE \"artist\" (\n",
      "\"Artist_ID\" int,\n",
      "\"Artist\" text,\n",
      "\"Age\" int,\n",
      "\"Famous_Title\" text,\n",
      "\"Famous_Release_date\" text,\n",
      "PRIMARY KEY (\"Artist_ID\")\n",
      ");\n",
      "CREATE TABLE \"volume\" (\n",
      "\"Volume_ID\" int,\n",
      "\"Volume_Issue\" text,\n",
      "\"Issue_Date\" text,\n",
      "\"Weeks_on_Top\" real,\n",
      "\"Song\" text,\n",
      "\"Artist_ID\" int,\n",
      "PRIMARY KEY (\"Volume_ID\"),\n",
      "FOREIGN KEY (\"Artist_ID\") REFERENCES \"artist\"(\"Artist_ID\")\n",
      ");\n",
      "CREATE TABLE \"music_festival\" (\n",
      "\"ID\" int,\n",
      "\"Music_Festival\" text,\n",
      "\"Date_of_ceremony\" text,\n",
      "\"Category\" text,\n",
      "\"Volume\" int,\n",
      "\"Result\" text,\n",
      "PRIMARY KEY (\"ID\"),\n",
      "FOREIGN KEY (\"Volume\") REFERENCES \"volume\"(\"Volume_ID\")\n",
      ");\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# The schema of the corresponding database\n",
    "schema = info[\"schema\"]\n",
    "print(schema)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Agent Implementation\n",
    "\n",
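    "Using AutoGen, a SQL agent can be implemented with a `ConversableAgent`. A `UserProxyAgent` runs the generated SQL query in the gym environment through the `execute_sql()` tool, and the agent uses the execution results as feedback to refine its query over multiple rounds of conversation. For an incorrect query, the feedback also includes the expected (gold) result, so the agent sees the target output while it revises."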
   ]
  },
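  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The multi-round feedback loop can be sketched without an LLM. In the snippet below, `propose` is a hypothetical stand-in for the `sql_writer` agent. It replays a scripted list of queries and ignores the feedback. `execute` plays the role of `execute_sql()`. The loop stops at the first response without an `error` key or after `max_rounds` attempts, which loosely mirrors `max_consecutive_auto_reply=5`. AutoGen's actual conversation loop is more general than this sketch.\n",
    "\n",
    "```python\n",
    "import sqlite3\n",
    "\n",
    "conn = sqlite3.connect(':memory:')\n",
    "conn.execute('CREATE TABLE artist (Artist_ID int, Artist text, Famous_Title text)')\n",
    "conn.executemany('INSERT INTO artist VALUES (?, ?, ?)', [(1, 'A', 'Title A'), (2, 'B', 'Title B')])\n",
    "gold = sorted(conn.execute('SELECT Famous_Title FROM artist').fetchall())\n",
    "\n",
    "\n",
    "def execute(sql):\n",
    "    # Stand-in for execute_sql(): success returns 'result'; failure returns 'error' (plus rows if the query ran).\n",
    "    try:\n",
    "        rows = sorted(conn.execute(sql).fetchall())\n",
    "    except sqlite3.Error as exc:\n",
    "        return {'error': str(exc)}\n",
    "    if rows != gold:\n",
    "        return {'error': 'The SQL query returned an incorrect result',\n",
    "                'wrong_result': rows, 'correct_result': gold}\n",
    "    return {'result': rows}\n",
    "\n",
    "\n",
    "def run_feedback_loop(propose, max_rounds=5):\n",
    "    # Ask for a query, execute it, and pass the response back until it has no 'error' key.\n",
    "    feedback = None\n",
    "    history = []\n",
    "    for _ in range(max_rounds):\n",
    "        sql = propose(feedback)\n",
    "        feedback = execute(sql)\n",
    "        history.append((sql, feedback))\n",
    "        if 'error' not in feedback:\n",
    "            break\n",
    "    return history\n",
    "\n",
    "\n",
    "# Hypothetical scripted proposer: ignores the feedback and replays two fixed queries.\n",
    "scripted = iter([\n",
    "    'SELECT Artist, Famous_Title FROM artist',  # extra column, judged incorrect\n",
    "    'SELECT Famous_Title FROM artist',          # matches the gold rows\n",
    "])\n",
    "history = run_feedback_loop(lambda feedback: next(scripted))\n",
    "print(len(history), history[-1][1])\n",
    "```"
   ]
  },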
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
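    "# Run AutoGen without Docker for any code execution.\n",
    "os.environ[\"AUTOGEN_USE_DOCKER\"] = \"False\"\n",
    "# Load LLM endpoint and API key settings from the OAI_CONFIG_LIST environment variable or file.\n",
    "config_list = config_list_from_json(env_or_file=\"OAI_CONFIG_LIST\")\n",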
    "\n",
    "\n",
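    "def check_termination(msg: Dict) -> bool:\n",
    "    # Only tool-response messages carry execution feedback from execute_sql().\n",
    "    if not msg.get(\"tool_responses\"):\n",
    "        return False\n",
    "    json_str = msg[\"tool_responses\"][0][\"content\"]\n",
    "    obj = json.loads(json_str)\n",
    "    # execute_sql() includes an \"error\" key only when the query failed or returned\n",
    "    # an incorrect result, so the conversation can stop once that key is absent.\n",
    "    return \"error\" not in obj\n",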
    "\n",
    "\n",
    "sql_writer = ConversableAgent(\n",
    "    \"sql_writer\",\n",
    "    llm_config={\"config_list\": config_list},\n",
    "    system_message=\"You are good at writing SQL queries. Always respond with a function call to execute_sql().\",\n",
    "    is_termination_msg=check_termination,\n",
    ")\n",
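    "# Executes sql_writer's tool calls automatically, without human input, for up to 5 consecutive replies.\n",
    "user_proxy = UserProxyAgent(\"user_proxy\", human_input_mode=\"NEVER\", max_consecutive_auto_reply=5)\n",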
    "\n",
    "\n",
    "@sql_writer.register_for_llm(description=\"Function for executing SQL query and returning a response\")\n",
    "@user_proxy.register_for_execution()\n",
    "def execute_sql(\n",
    "    reflection: Annotated[str, \"Think about what to do\"], sql: Annotated[str, \"SQL query\"]\n",
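    ") -> Annotated[dict, \"Dictionary with 'result' on success, or 'error', 'wrong_result' and 'correct_result' on failure\"]:\n",
    "    # Execute the query in the Spider environment and read back its feedback.\n",
    "    observation, reward, _, _, info = gym.step(sql)\n",
    "    error = observation[\"feedback\"][\"error\"]\n",
    "    # A query that runs but does not match the gold result is also reported as an error.\n",
    "    if not error and reward == 0:\n",
    "        error = \"The SQL query returned an incorrect result\"\n",
    "    if error:\n",
    "        # Return the query's output alongside the expected (gold) output as feedback.\n",
    "        return {\n",
    "            \"error\": error,\n",
    "            \"wrong_result\": observation[\"feedback\"][\"result\"],\n",
    "            \"correct_result\": info[\"gold_result\"],\n",
    "        }\n",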
    "    else:\n",
    "        return {\n",
    "            \"result\": observation[\"feedback\"][\"result\"],\n",
    "        }"
   ]
  },
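  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The termination rule can be exercised on its own. The snippet below reimplements the rule in a standalone function and applies it to the two tool responses from the run further down. The responses are wrapped in a minimal message dict that carries only the `tool_responses` / `content` fields the check reads; real AutoGen messages carry more fields. A message without tool responses never ends the chat.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "\n",
    "def check_termination(msg):\n",
    "    # Same rule as the agent's termination check: stop once a tool response has no 'error' key.\n",
    "    if not msg.get('tool_responses'):\n",
    "        return False\n",
    "    obj = json.loads(msg['tool_responses'][0]['content'])\n",
    "    return 'error' not in obj\n",
    "\n",
    "\n",
    "# Tool responses copied from the run below, wrapped in the message shape the check reads.\n",
    "first = {'tool_responses': [{'content': json.dumps({\n",
    "    'error': 'The SQL query returned an incorrect result',\n",
    "    'wrong_result': [['Ophiolatry', 'Antievangelistical Process (re-release)'], ['Triumfall', 'Antithesis of All Flesh']],\n",
    "    'correct_result': [['Antievangelistical Process (re-release)'], ['Antithesis of All Flesh']],\n",
    "})}]}\n",
    "second = {'tool_responses': [{'content': json.dumps({\n",
    "    'result': [['Antievangelistical Process (re-release)'], ['Antithesis of All Flesh']],\n",
    "})}]}\n",
    "plain = {'content': 'Generate a SQL query to answer the following question.'}\n",
    "\n",
    "print([check_termination(m) for m in (first, second, plain)])  # [False, True, False]\n",
    "```"
   ]
  },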
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
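    "The agent then takes the schema and the natural language question as input and generates the SQL query. In the run below, the first query also returns the artist names and is marked incorrect. The agent uses the execution feedback to drop the extra column, and its second query matches the expected result."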
   ]
  },
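  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both queries the agent tries below can be run against a toy copy of the `artist` and `volume` tables. The rows here are hypothetical, not the Spider data. The `NOT EXISTS` anti-join keeps only artists with no matching volume. The extra `Artist` column in the first query changes the shape of the result, so it no longer matches the single-column expected result.\n",
    "\n",
    "```python\n",
    "import sqlite3\n",
    "\n",
    "conn = sqlite3.connect(':memory:')\n",
    "# Columns trimmed to the ones the queries touch; rows are hypothetical.\n",
    "conn.execute('CREATE TABLE artist (Artist_ID int PRIMARY KEY, Artist text, Famous_Title text)')\n",
    "conn.execute('CREATE TABLE volume (Volume_ID int PRIMARY KEY, Artist_ID int REFERENCES artist(Artist_ID))')\n",
    "conn.executemany('INSERT INTO artist VALUES (?, ?, ?)', [\n",
    "    (1, 'Artist A', 'Title A'),  # has a volume\n",
    "    (2, 'Artist B', 'Title B'),  # no volume\n",
    "    (3, 'Artist C', 'Title C'),  # no volume\n",
    "])\n",
    "conn.execute('INSERT INTO volume VALUES (10, 1)')\n",
    "\n",
    "anti_join = 'FROM artist a WHERE NOT EXISTS (SELECT 1 FROM volume v WHERE v.Artist_ID = a.Artist_ID)'\n",
    "first_try = conn.execute('SELECT a.Artist, a.Famous_Title ' + anti_join).fetchall()\n",
    "second_try = conn.execute('SELECT a.Famous_Title ' + anti_join).fetchall()\n",
    "print(first_try)   # two columns per row\n",
    "print(second_try)  # one column per row, the shape the question asks for\n",
    "```"
   ]
  },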
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33muser_proxy\u001b[0m (to sql_writer):\n",
      "\n",
      "Below is the schema for a SQL database:\n",
      "CREATE TABLE \"artist\" (\n",
      "\"Artist_ID\" int,\n",
      "\"Artist\" text,\n",
      "\"Age\" int,\n",
      "\"Famous_Title\" text,\n",
      "\"Famous_Release_date\" text,\n",
      "PRIMARY KEY (\"Artist_ID\")\n",
      ");\n",
      "CREATE TABLE \"volume\" (\n",
      "\"Volume_ID\" int,\n",
      "\"Volume_Issue\" text,\n",
      "\"Issue_Date\" text,\n",
      "\"Weeks_on_Top\" real,\n",
      "\"Song\" text,\n",
      "\"Artist_ID\" int,\n",
      "PRIMARY KEY (\"Volume_ID\"),\n",
      "FOREIGN KEY (\"Artist_ID\") REFERENCES \"artist\"(\"Artist_ID\")\n",
      ");\n",
      "CREATE TABLE \"music_festival\" (\n",
      "\"ID\" int,\n",
      "\"Music_Festival\" text,\n",
      "\"Date_of_ceremony\" text,\n",
      "\"Category\" text,\n",
      "\"Volume\" int,\n",
      "\"Result\" text,\n",
      "PRIMARY KEY (\"ID\"),\n",
      "FOREIGN KEY (\"Volume\") REFERENCES \"volume\"(\"Volume_ID\")\n",
      ");\n",
      "\n",
      "Generate a SQL query to answer the following question:\n",
      "Find the famous titles of artists that do not have any volume.\n",
      "\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[31m\n",
      ">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
      "\u001b[33msql_writer\u001b[0m (to user_proxy):\n",
      "\n",
      "\u001b[32m***** Suggested tool Call (call_eAu0OEzS8l3QvN3jQSn4w0hJ): execute_sql *****\u001b[0m\n",
      "Arguments: \n",
      "{\"reflection\":\"Generating SQL to find famous titles of artists without any volume\",\"sql\":\"SELECT a.Artist, a.Famous_Title FROM artist a WHERE NOT EXISTS (SELECT 1 FROM volume v WHERE v.Artist_ID = a.Artist_ID)\"}\n",
      "\u001b[32m****************************************************************************\u001b[0m\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[35m\n",
      ">>>>>>>> EXECUTING FUNCTION execute_sql...\u001b[0m\n",
      "\u001b[33muser_proxy\u001b[0m (to sql_writer):\n",
      "\n",
      "\u001b[33muser_proxy\u001b[0m (to sql_writer):\n",
      "\n",
      "\u001b[32m***** Response from calling tool \"call_eAu0OEzS8l3QvN3jQSn4w0hJ\" *****\u001b[0m\n",
      "{\"error\": \"The SQL query returned an incorrect result\", \"wrong_result\": [[\"Ophiolatry\", \"Antievangelistical Process (re-release)\"], [\"Triumfall\", \"Antithesis of All Flesh\"]], \"correct_result\": [[\"Antievangelistical Process (re-release)\"], [\"Antithesis of All Flesh\"]]}\n",
      "\u001b[32m**********************************************************************\u001b[0m\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[31m\n",
      ">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
      "\u001b[33msql_writer\u001b[0m (to user_proxy):\n",
      "\n",
      "\u001b[32m***** Suggested tool Call (call_5LXoKqdZ17kPCOHJbbpSz2yk): execute_sql *****\u001b[0m\n",
      "Arguments: \n",
      "{\"reflection\":\"Adjusting SQL to only select famous titles and exclude artist names for artists without any volume.\",\"sql\":\"SELECT a.Famous_Title FROM artist a WHERE NOT EXISTS (SELECT 1 FROM volume v WHERE v.Artist_ID = a.Artist_ID)\"}\n",
      "\u001b[32m****************************************************************************\u001b[0m\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[35m\n",
      ">>>>>>>> EXECUTING FUNCTION execute_sql...\u001b[0m\n",
      "\u001b[33muser_proxy\u001b[0m (to sql_writer):\n",
      "\n",
      "\u001b[33muser_proxy\u001b[0m (to sql_writer):\n",
      "\n",
      "\u001b[32m***** Response from calling tool \"call_5LXoKqdZ17kPCOHJbbpSz2yk\" *****\u001b[0m\n",
      "{\"result\": [[\"Antievangelistical Process (re-release)\"], [\"Antithesis of All Flesh\"]]}\n",
      "\u001b[32m**********************************************************************\u001b[0m\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[31m\n",
      ">>>>>>>> NO HUMAN INPUT RECEIVED.\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "message = f\"\"\"Below is the schema for a SQL database:\n",
    "{schema}\n",
    "Generate a SQL query to answer the following question:\n",
    "{question}\n",
    "\"\"\"\n",
    "\n",
    "user_proxy.initiate_chat(sql_writer, message=message)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
